{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "senior-empire",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "latest-pastor",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "from tqdm import tqdm_notebook"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ecological-greensboro",
   "metadata": {},
   "source": [
    "## Load Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ahead-calcium",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = \"data/AAAI2023Competition/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "hydraulic-heading",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_train_valid = pd.read_csv(os.path.join(data_dir,\"train_valid_sequences.csv\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "quick-landscape",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>fold</th>\n",
       "      <th>uid</th>\n",
       "      <th>questions</th>\n",
       "      <th>concepts</th>\n",
       "      <th>responses</th>\n",
       "      <th>timestamps</th>\n",
       "      <th>selectmasks</th>\n",
       "      <th>is_repeat</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>11066</td>\n",
       "      <td>3751,3752,3753,3754,1990,3739,3740,3742,3756,3...</td>\n",
       "      <td>187,187,374,187,374,188,188,228,166,170,221,40...</td>\n",
       "      <td>1,1,1,0,1,1,1,1,0,0,0,1,1,1,1,1,1,1,1,1,0,1,1,...</td>\n",
       "      <td>1595229836000,1595233013000,1595233687000,1595...</td>\n",
       "      <td>1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,...</td>\n",
       "      <td>0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   fold    uid                                          questions  \\\n",
       "0     0  11066  3751,3752,3753,3754,1990,3739,3740,3742,3756,3...   \n",
       "\n",
       "                                            concepts  \\\n",
       "0  187,187,374,187,374,188,188,228,166,170,221,40...   \n",
       "\n",
       "                                           responses  \\\n",
       "0  1,1,1,0,1,1,1,1,0,0,0,1,1,1,1,1,1,1,1,1,0,1,1,...   \n",
       "\n",
       "                                          timestamps  \\\n",
       "0  1595229836000,1595233013000,1595233687000,1595...   \n",
       "\n",
       "                                         selectmasks  \\\n",
       "0  1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,...   \n",
       "\n",
       "                                           is_repeat  \n",
       "0  0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,...  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_train_valid.head(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "czech-tunnel",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hold out one fold for validation; rows from every other fold form the training set.\n",
    "valid_fold = 0 # id of the fold held out as the validation split\n",
    "df_train = df_train_valid[df_train_valid['fold']!=valid_fold].copy()# training rows: all folds except valid_fold\n",
    "df_valid = df_train_valid[df_train_valid['fold']==valid_fold].copy()# validation rows: only valid_fold"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "prostate-onion",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((26728, 8), (6669, 8))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_train.shape,df_valid.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "false-recall",
   "metadata": {},
   "outputs": [],
   "source": [
    "def flatten_dataset(df):\n",
    "    \"\"\"Expand sequence-level rows into one row per interaction.\n",
    "\n",
    "    Each row of ``df`` stores comma-separated ``questions``, ``concepts`` and\n",
    "    ``responses`` strings for one ``uid``. The three sequences are walked in\n",
    "    lockstep, and reading of a row stops at the first response equal to\n",
    "    \"-1\", which is treated as the start of the padding.\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame with ``uid``, ``questions``, ``concepts`` and\n",
    "            ``responses`` columns.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with integer ``uid``, ``question``, ``concept`` and\n",
    "        ``response`` columns, one row per non-padded interaction. The\n",
    "        columns are present even when no interaction is found.\n",
    "    \"\"\"\n",
    "    columns = [\"uid\", \"question\", \"concept\", \"response\"]\n",
    "    interaction_list = []\n",
    "    # iterrows() has no length, so hand the row count to the progress bar.\n",
    "    for _, row in tqdm_notebook(df.iterrows(), total=len(df)):\n",
    "        uid = int(row['uid'])\n",
    "        for question, concept, response in zip(row['questions'].split(\",\"),\n",
    "                                               row['concepts'].split(\",\"),\n",
    "                                               row['responses'].split(\",\")):\n",
    "            if response == \"-1\":  # everything from here on is padding\n",
    "                break\n",
    "            interaction_list.append(\n",
    "                (uid, int(question), int(concept), int(response)))\n",
    "    # Explicit columns keep the schema stable for an empty result as well.\n",
    "    df_interaction = pd.DataFrame(interaction_list, columns=columns)\n",
    "    print(f\"# interaction is {len(interaction_list)}\")\n",
    "    return df_interaction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "equal-improvement",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6bd20c19f673432494c11e587aefe5b4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# interaction is 4109190\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "81e7941dec5f4f06a9235e595f393033",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# interaction is 1029854\n"
     ]
    }
   ],
   "source": [
    "df_train_inter = flatten_dataset(df_train)\n",
    "df_valid_inter = flatten_dataset(df_valid)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "documented-force",
   "metadata": {},
   "source": [
    "## Train a Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "appreciated-tractor",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>uid</th>\n",
       "      <th>question</th>\n",
       "      <th>concept</th>\n",
       "      <th>response</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>11779</td>\n",
       "      <td>3934</td>\n",
       "      <td>243</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     uid  question  concept  response\n",
       "0  11779      3934      243         1"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_train_inter.head(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "provincial-daisy",
   "metadata": {},
   "source": [
    "### Question Level"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "deluxe-amendment",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean response of every question, computed on the training interactions.\n",
    "responses_by_question = df_train_inter.groupby('question')['response']\n",
    "top_answers_que_level = responses_by_question.mean()\n",
    "# Plain dict keyed by question id for fast lookups at prediction time.\n",
    "top_answers_dict_que_level = top_answers_que_level.to_dict()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "static-spider",
   "metadata": {},
   "source": [
    "### Concept Level"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "complicated-cricket",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean response of every concept, computed on the training interactions.\n",
    "responses_by_concept = df_train_inter.groupby('concept')['response']\n",
    "top_answers_concept_level = responses_by_concept.mean()\n",
    "# Plain dict keyed by concept id for fast lookups at prediction time.\n",
    "top_answers_dict_concept_level = top_answers_concept_level.to_dict()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "pleased-infrared",
   "metadata": {},
   "source": [
    "## Predict on the Validation Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "golden-assembly",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score each validation interaction with the stored mean response of its\n",
    "# question / concept; ids missing from a lookup table fall back to 0.\n",
    "df_valid_inter['que_pred'] = df_valid_inter['question'].map(\n",
    "    lambda question_id: top_answers_dict_que_level.get(question_id, 0))\n",
    "df_valid_inter['concept_pred'] = df_valid_inter['concept'].map(\n",
    "    lambda concept_id: top_answers_dict_concept_level.get(concept_id, 0))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "convertible-catch",
   "metadata": {},
   "source": [
    "## Evaluate on the Validation Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "silent-angle",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import roc_auc_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "alleged-worry",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7323341119020579"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "que_auc = roc_auc_score(df_valid_inter['response'],df_valid_inter['que_pred'])\n",
    "que_auc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "antique-convention",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.619159927619923"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "concept_auc = roc_auc_score(df_valid_inter['response'],df_valid_inter['concept_pred'])\n",
    "concept_auc"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "patent-mediterranean",
   "metadata": {},
   "source": [
    "## Train on the Whole Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "pleasant-surrey",
   "metadata": {},
   "outputs": [],
   "source": [
    "top_answers_que_level = pd.concat([df_train_inter,df_valid_inter]).groupby(\n",
    "    'question')['response'].agg(\"mean\")\n",
    "top_answers_dict_que_level = top_answers_que_level.to_dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "modern-hybrid",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.735416364022037"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check only: the question-level lookup was rebuilt from the train\n",
    "# and validation interactions combined, so the validation rows scored here\n",
    "# were part of the fit. This AUC is not a held-out estimate.\n",
    "df_valid_inter['que_pred'] = df_valid_inter['question'].apply(\n",
    "    lambda x: top_answers_dict_que_level.get(x, 0))\n",
    "que_auc = roc_auc_score(df_valid_inter['response'],df_valid_inter['que_pred'])\n",
    "que_auc"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "local-claim",
   "metadata": {},
   "source": [
    "## Predict on the Test Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "southern-twelve",
   "metadata": {},
   "source": [
    "From the experiment above, the question-level model performs better than the concept-level model, so we use the question-level model as our final model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "sharp-hotel",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_test = pd.read_csv(os.path.join(data_dir,\"pykt_test.csv\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "documentary-bacteria",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>uid</th>\n",
       "      <th>questions</th>\n",
       "      <th>concepts</th>\n",
       "      <th>responses</th>\n",
       "      <th>timestamps</th>\n",
       "      <th>is_repeat</th>\n",
       "      <th>num_test</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>8572</td>\n",
       "      <td>2203,268,266,271,270,269,2204,274,277,2206,220...</td>\n",
       "      <td>140,139,9,144,143,142,119,119,62,65,62,65,148,...</td>\n",
       "      <td>1,1,1,0,1,1,1,1,0,0,1,1,1,0,1,1,1,1,0,1,1,0,1,...</td>\n",
       "      <td>1594547742000,1594547742000,1594547742000,1594...</td>\n",
       "      <td>0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,...</td>\n",
       "      <td>186</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>179</td>\n",
       "      <td>3436,3437,3438,5243,2335,292,293,297,3440,1950...</td>\n",
       "      <td>155,155,155,55,55,155,155,55,117,307,55,307,55...</td>\n",
       "      <td>1,1,1,1,1,0,0,1,1,1,1,1,1,0,0,1,1,1,0,1,1,1,0,...</td>\n",
       "      <td>1599133602000,1599133602000,1599133602000,1599...</td>\n",
       "      <td>0,0,0,0,0,0,0,0,0,0,1,0,1,0,1,0,0,0,0,0,0,0,0,...</td>\n",
       "      <td>138</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>5664</td>\n",
       "      <td>1367,2504,983,4358,4356,4356,1370,4355,4357,43...</td>\n",
       "      <td>187,188,187,335,188,471,335,188,188,335,188,18...</td>\n",
       "      <td>1,1,0,1,1,1,1,1,1,0,1,0,1,0,1,1,1,1,0,0,1,1,1,...</td>\n",
       "      <td>1595219073000,1595244223000,1595244223000,1595...</td>\n",
       "      <td>0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,...</td>\n",
       "      <td>199</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3587</td>\n",
       "      <td>268,266,267,270,271,269,1937,2323,2322,812,811...</td>\n",
       "      <td>139,9,141,143,144,142,232,231,232,30,30,30,147...</td>\n",
       "      <td>1,1,0,1,1,0,1,1,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,...</td>\n",
       "      <td>1597571706000,1597571706000,1597571706000,1597...</td>\n",
       "      <td>0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,...</td>\n",
       "      <td>163</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>6711</td>\n",
       "      <td>2208,795,794,2333,2334,791,1945,1138,292,1946,...</td>\n",
       "      <td>155,304,303,155,155,155,155,155,155,155,155,10...</td>\n",
       "      <td>1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,...</td>\n",
       "      <td>1599216384000,1599290697000,1599290697000,1599...</td>\n",
       "      <td>0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,1,0,1,0,0,0,...</td>\n",
       "      <td>270</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3608</th>\n",
       "      <td>601</td>\n",
       "      <td>3293,3294,3292,5373,5372,5374,1130,3292,3293,3...</td>\n",
       "      <td>587,588,587,66,9,11,366,587,587,588,17,714,66,...</td>\n",
       "      <td>1,1,1,1,0,1,1,1,1,1,1,0,1,1,1,0,0,1,1,0,1,1,1,...</td>\n",
       "      <td>1594458929000,1594458929000,1594458929000,1594...</td>\n",
       "      <td>0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0,1,0,1,0,0,0,1,...</td>\n",
       "      <td>129</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3609</th>\n",
       "      <td>2540</td>\n",
       "      <td>2309,2313,2318,2318,2328,2328,862,3268,3031,86...</td>\n",
       "      <td>365,142,231,232,497,488,223,144,144,142,140,9,...</td>\n",
       "      <td>1,0,1,1,0,0,1,0,0,1,1,1,0,1,1,1,1,1,1,1,1,1,1,...</td>\n",
       "      <td>1594446365000,1594446676000,1594447133000,1594...</td>\n",
       "      <td>0,0,0,1,0,1,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,...</td>\n",
       "      <td>226</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3610</th>\n",
       "      <td>14945</td>\n",
       "      <td>3421,3419,3420,2202,1130,1129,3263,2310,2310,2...</td>\n",
       "      <td>140,502,365,365,366,365,139,66,9,140,7,11,9,6,...</td>\n",
       "      <td>1,0,1,1,1,1,0,1,1,1,1,1,1,1,0,1,1,1,1,0,1,1,0,...</td>\n",
       "      <td>1595073081000,1595073081000,1595073081000,1595...</td>\n",
       "      <td>0,0,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0,...</td>\n",
       "      <td>122</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3611</th>\n",
       "      <td>9105</td>\n",
       "      <td>1985,451,452,1986,1987,1989,1988,1990,453,454,...</td>\n",
       "      <td>475,138,187,188,188,374,188,374,228,228,214,21...</td>\n",
       "      <td>0,0,1,0,1,0,1,0,1,1,1,1,1,1,1,1,1,1,1,1,0,1,1,...</td>\n",
       "      <td>1594466792000,1594467267000,1594467816000,1594...</td>\n",
       "      <td>0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,...</td>\n",
       "      <td>241</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3612</th>\n",
       "      <td>6075</td>\n",
       "      <td>3934,3935,3936,3937,3938,3939,3940,3941,3942,3...</td>\n",
       "      <td>243,5,243,243,243,2,2,244,244,367,313,367,313,...</td>\n",
       "      <td>1,1,1,1,0,1,0,1,1,1,1,1,0,1,1,1,1,1,0,1,1,1,0,...</td>\n",
       "      <td>1599301240000,1599301448000,1599301588000,1599...</td>\n",
       "      <td>0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,...</td>\n",
       "      <td>211</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>3613 rows × 7 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        uid                                          questions  \\\n",
       "0      8572  2203,268,266,271,270,269,2204,274,277,2206,220...   \n",
       "1       179  3436,3437,3438,5243,2335,292,293,297,3440,1950...   \n",
       "2      5664  1367,2504,983,4358,4356,4356,1370,4355,4357,43...   \n",
       "3      3587  268,266,267,270,271,269,1937,2323,2322,812,811...   \n",
       "4      6711  2208,795,794,2333,2334,791,1945,1138,292,1946,...   \n",
       "...     ...                                                ...   \n",
       "3608    601  3293,3294,3292,5373,5372,5374,1130,3292,3293,3...   \n",
       "3609   2540  2309,2313,2318,2318,2328,2328,862,3268,3031,86...   \n",
       "3610  14945  3421,3419,3420,2202,1130,1129,3263,2310,2310,2...   \n",
       "3611   9105  1985,451,452,1986,1987,1989,1988,1990,453,454,...   \n",
       "3612   6075  3934,3935,3936,3937,3938,3939,3940,3941,3942,3...   \n",
       "\n",
       "                                               concepts  \\\n",
       "0     140,139,9,144,143,142,119,119,62,65,62,65,148,...   \n",
       "1     155,155,155,55,55,155,155,55,117,307,55,307,55...   \n",
       "2     187,188,187,335,188,471,335,188,188,335,188,18...   \n",
       "3     139,9,141,143,144,142,232,231,232,30,30,30,147...   \n",
       "4     155,304,303,155,155,155,155,155,155,155,155,10...   \n",
       "...                                                 ...   \n",
       "3608  587,588,587,66,9,11,366,587,587,588,17,714,66,...   \n",
       "3609  365,142,231,232,497,488,223,144,144,142,140,9,...   \n",
       "3610  140,502,365,365,366,365,139,66,9,140,7,11,9,6,...   \n",
       "3611  475,138,187,188,188,374,188,374,228,228,214,21...   \n",
       "3612  243,5,243,243,243,2,2,244,244,367,313,367,313,...   \n",
       "\n",
       "                                              responses  \\\n",
       "0     1,1,1,0,1,1,1,1,0,0,1,1,1,0,1,1,1,1,0,1,1,0,1,...   \n",
       "1     1,1,1,1,1,0,0,1,1,1,1,1,1,0,0,1,1,1,0,1,1,1,0,...   \n",
       "2     1,1,0,1,1,1,1,1,1,0,1,0,1,0,1,1,1,1,0,0,1,1,1,...   \n",
       "3     1,1,0,1,1,0,1,1,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,...   \n",
       "4     1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,...   \n",
       "...                                                 ...   \n",
       "3608  1,1,1,1,0,1,1,1,1,1,1,0,1,1,1,0,0,1,1,0,1,1,1,...   \n",
       "3609  1,0,1,1,0,0,1,0,0,1,1,1,0,1,1,1,1,1,1,1,1,1,1,...   \n",
       "3610  1,0,1,1,1,1,0,1,1,1,1,1,1,1,0,1,1,1,1,0,1,1,0,...   \n",
       "3611  0,0,1,0,1,0,1,0,1,1,1,1,1,1,1,1,1,1,1,1,0,1,1,...   \n",
       "3612  1,1,1,1,0,1,0,1,1,1,1,1,0,1,1,1,1,1,0,1,1,1,0,...   \n",
       "\n",
       "                                             timestamps  \\\n",
       "0     1594547742000,1594547742000,1594547742000,1594...   \n",
       "1     1599133602000,1599133602000,1599133602000,1599...   \n",
       "2     1595219073000,1595244223000,1595244223000,1595...   \n",
       "3     1597571706000,1597571706000,1597571706000,1597...   \n",
       "4     1599216384000,1599290697000,1599290697000,1599...   \n",
       "...                                                 ...   \n",
       "3608  1594458929000,1594458929000,1594458929000,1594...   \n",
       "3609  1594446365000,1594446676000,1594447133000,1594...   \n",
       "3610  1595073081000,1595073081000,1595073081000,1595...   \n",
       "3611  1594466792000,1594467267000,1594467816000,1594...   \n",
       "3612  1599301240000,1599301448000,1599301588000,1599...   \n",
       "\n",
       "                                              is_repeat  num_test  \n",
       "0     0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,...       186  \n",
       "1     0,0,0,0,0,0,0,0,0,0,1,0,1,0,1,0,0,0,0,0,0,0,0,...       138  \n",
       "2     0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,...       199  \n",
       "3     0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,...       163  \n",
       "4     0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,1,0,1,0,0,0,...       270  \n",
       "...                                                 ...       ...  \n",
       "3608  0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0,1,0,1,0,0,0,1,...       129  \n",
       "3609  0,0,0,1,0,1,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,...       226  \n",
       "3610  0,0,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0,...       122  \n",
       "3611  0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,...       241  \n",
       "3612  0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,...       211  \n",
       "\n",
       "[3613 rows x 7 columns]"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "threaded-color",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build one comma-separated prediction string per test sequence.\n",
    "predict_str_list = []\n",
    "num_test = 0  # running count of predicted interactions\n",
    "for _, row in df_test.iterrows():\n",
    "    steps = zip(row['questions'].split(\",\"),\n",
    "                row['responses'].split(\",\"),\n",
    "                row['is_repeat'].split(\",\"))\n",
    "    # Cast every field of every step to int before any filtering happens.\n",
    "    parsed_steps = [(int(q), int(r), int(rep)) for q, r, rep in steps]\n",
    "    # Keep non-repeat steps whose response is -1 (the positions to predict);\n",
    "    # questions missing from the lookup table get a prediction of 0.\n",
    "    predict_results = [top_answers_dict_que_level.get(q, 0)\n",
    "                       for q, r, rep in parsed_steps\n",
    "                       if rep == 0 and r == -1]\n",
    "    num_test += len(predict_results)\n",
    "    predict_str = \",\".join(map(str, predict_results))\n",
    "    predict_str_list.append(predict_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "diverse-seattle",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "552290"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "num_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "regulation-spokesman",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_submit = pd.DataFrame({\"responses\":predict_str_list})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "surgical-vintage",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_submit.to_csv(\"prediction.csv\",index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "computational-farmer",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "4s_dev",
   "language": "python",
   "name": "4s_dev"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
